import torch
import os 
import argparse
def checkpoint_dir(model_name):
    """Return the directory that holds the checkpoint *model_name*.

    Names containing ``'dalle'`` are looked up in ``./outputs/dalle_models``;
    every other name is treated as a VAE checkpoint in ``./outputs/vae_models``.
    The converted CPU copy is written back into the same directory.
    """
    if 'dalle' in model_name:
        return './outputs/dalle_models'
    return './outputs/vae_models'


def cpu_checkpoint_name(model_name):
    """Return the file name used for the CPU copy of *model_name*.

    Only the final extension is stripped, so ``'dalle_v1.5.pt'`` becomes
    ``'dalle_v1.5-cpu.pt'`` instead of being truncated at the first dot.
    """
    stem, _ext = os.path.splitext(model_name)
    return stem + '-cpu.pt'


def convert_to_cpu(loaded_obj, version=1):
    """Build a new checkpoint dict whose weight tensors all live on the CPU.

    Args:
        loaded_obj: checkpoint dict containing ``'hparams'`` and ``'weights'``
            (a mapping of parameter name -> tensor). When ``version != 1`` it
            must also contain ``'epoch'`` and may contain ``'opt_state'`` and
            ``'scheduler_state'`` (missing ones are stored as ``None``).
        version: 1 keeps only hparams and weights; any other value also keeps
            epoch, optimizer state and scheduler state.

    Returns:
        A new dict suitable for ``torch.save``. ``loaded_obj`` is not mutated,
        and the key order of ``'weights'`` is preserved.

    Raises:
        KeyError: if ``'hparams'``, ``'weights'`` or (for ``version != 1``)
            ``'epoch'`` is missing.
    """
    hparams = loaded_obj['hparams']
    weights = loaded_obj['weights']
    if version != 1:
        epoch = loaded_obj['epoch']

    # clone() guarantees each saved tensor is an independent copy, even when
    # the tensor is already on the CPU (where .cpu() returns the same tensor).
    cpu_weights = {name: w.cpu().clone() for name, w in weights.items()}

    save_obj = {
        'hparams': hparams,
        'weights': cpu_weights,
    }
    if version != 1:
        save_obj['epoch'] = epoch
        save_obj['opt_state'] = loaded_obj.get('opt_state')
        save_obj['scheduler_state'] = loaded_obj.get('scheduler_state')
    return save_obj


def main(argv=None):
    """Convert a GPU checkpoint into ``<name>-cpu.pt`` next to the original.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Re-save a DALL-E / VAE checkpoint so it can be loaded without a GPU.')
    parser.add_argument('--model_name', type=str, required=True,
                        help='the model which needs to transfer from gpu to cpu')
    parser.add_argument('--version', type=int, default=1,
                        help='checkpoint layout: 1 saves only hparams and weights; '
                             'any other value also keeps epoch, optimizer and scheduler state')
    args = parser.parse_args(argv)

    model_dir = checkpoint_dir(args.model_name)
    # map_location='cpu' lets a CPU-only machine read a checkpoint saved from
    # GPU tensors, and also moves tensors inside opt_state/scheduler_state off
    # the GPU (the original load kept them on their saved device).
    loaded_obj = torch.load(os.path.join(model_dir, args.model_name), map_location='cpu')
    save_obj = convert_to_cpu(loaded_obj, args.version)
    torch.save(save_obj, os.path.join(model_dir, cpu_checkpoint_name(args.model_name)))


if __name__ == '__main__':
    main()
